--- title: SRResNet: Photo-Realistic Single Image Super-Resolution using a Generative Adversarial Network (CVPR 2017, Oral) keywords: fastai sidebar: home_sidebar ---
{% raw %}
{% endraw %} {% raw %}
%reload_ext autoreload
%autoreload 2
%matplotlib inline
{% endraw %} {% raw %}
import os
import cv2
import numpy as np
import re
import random
from tqdm import tqdm
from matplotlib import pyplot as plt
import PIL
{% endraw %} {% raw %}
{% endraw %} {% raw %}
import sys
sys.path.append('..')
from superres.datasets import *
from superres.databunch import *
{% endraw %} {% raw %}
seed = 8610
random.seed(seed)
np.random.seed(seed)
{% endraw %}

Model

{% raw %}
# https://qiita.com/pacifinapacific/items/ec338a500015ae8c33fe
{% endraw %} {% raw %}
#exort
class ResidualBlock(nn.Module):
    def __init__(self,input_channel):
        super(ResidualBlock,self).__init__()
        self.residualblock=nn.Sequential(
            nn.Conv2d(input_channel,input_channel,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(input_channel),
            nn.PReLU(),
            nn.Conv2d(input_channel,input_channel,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(input_channel))
        
    def forward(self,x):
        residual=self.residualblock(x)
        return x+residual
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class Pixcelshuffer[source]

Pixcelshuffer(input_channel, r) :: Module

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:to, etc.

{% endraw %} {% raw %}
{% endraw %} {% raw %}

class SRResNet[source]

SRResNet(image_size) :: Module

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:to, etc.

{% endraw %} {% raw %}
test_input=torch.ones(1,3,64,64)
g=SRResNet(64)
test_input=test_input.cuda()
g=g.cuda()
out=g(test_input)
print(out.size())
torch.Size([1, 3, 256, 256])
{% endraw %}

DataBunch

{% raw %}
train_hr = div2k_train_hr_crop_256
{% endraw %} {% raw %}
in_size = 64
out_size = 256
scale = 4
bs = 8
{% endraw %} {% raw %}
data = create_sr_databunch(train_hr, in_size=in_size, out_size=out_size, scale=scale, bs=bs, seed=seed)
print(data)
data.show_batch()
ImageDataBunch;

Train: LabelList (25245 items)
x: ImageImageList
Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64)
y: ImageImageList
Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)
Path: /home/jovyan/notebook/datasets/DIV2K/DIV2K_train_HR_crop/256;

Valid: LabelList (6311 items)
x: ImageImageList
Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64)
y: ImageImageList
Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)
Path: /home/jovyan/notebook/datasets/DIV2K/DIV2K_train_HR_crop/256;

Test: None
{% endraw %}

Training

{% raw %}
model = SRResNet(64)
loss_func = MSELossFlat()
metrics = [m_psnr, m_ssim]
learn = Learner(data, model, loss_func=loss_func, metrics=metrics)
learn.path = Path('.')
model_name = model.__class__.__name__
{% endraw %} {% raw %}
lr_find(learn)
learn.recorder.plot(suggestion=True)
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
Min numerical gradient: 3.98E-04
Min loss divided by 10: 3.63E-04
{% endraw %} {% raw %}
lr = 1e-3
lrs = slice(lr)
epoch = 3
pct_start = 0.3
wd = 1e-3
save_fname = model_name
{% endraw %} {% raw %}
callbacks = [ShowGraph(learn), SaveModelCallback(learn, name=save_fname)]
{% endraw %} {% raw %}
learn.fit_one_cycle(epoch, lrs, pct_start=pct_start, wd=wd, callbacks=callbacks)
epoch train_loss valid_loss m_psnr m_ssim time
0 0.106310 0.072490 26.315109 0.350450 09:39
1 0.088919 0.064380 28.715475 0.400532 09:41
2 0.084303 0.054596 32.866898 0.420791 09:40
Better model found at epoch 0 with valid_loss value: 0.07248980551958084.
Better model found at epoch 1 with valid_loss value: 0.06438024342060089.
Better model found at epoch 2 with valid_loss value: 0.054595768451690674.
{% endraw %} {% raw %}
learn.show_results()
{% endraw %}

Test

{% raw %}
test_hr = set14_hr
{% endraw %} {% raw %}
il_test_x = ImageImageList.from_folder(test_hr, after_open=partial(after_open_image, size=in_size, scale=4,))
il_test_y = ImageImageList.from_folder(test_hr, after_open=partial(after_open_image, size=out_size))
il_test_x_up = ImageImageList.from_folder(test_hr, after_open=partial(after_open_image, size=out_size, scale=4, sizeup=True))
{% endraw %} {% raw %}
sr_test_upscale(learn, il_test_x, il_test_y, il_test_x_up, model_name)
bicubic: PSNR:24.11,SSIM:0.7822
SRResNet:	 PSNR:23.81,SSIM:0.7822
{% endraw %}

Report

{% raw %}
model
SRResNet(
  (pre_layer): Sequential(
    (0): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
    (1): PReLU(num_parameters=1)
  )
  (residual_layer): Sequential(
    (0): ResidualBlock(
      (residualblock): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): PReLU(num_parameters=1)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): ResidualBlock(
      (residualblock): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): PReLU(num_parameters=1)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): ResidualBlock(
      (residualblock): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): PReLU(num_parameters=1)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (3): ResidualBlock(
      (residualblock): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): PReLU(num_parameters=1)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (4): ResidualBlock(
      (residualblock): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): PReLU(num_parameters=1)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (middle_layer): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pixcelshuffer_layer): Sequential(
    (0): Pixcelshuffer(
      (layer): Sequential(
        (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): PixelShuffle(upscale_factor=2)
        (2): PReLU(num_parameters=1)
      )
    )
    (1): Pixcelshuffer(
      (layer): Sequential(
        (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): PixelShuffle(upscale_factor=2)
        (2): PReLU(num_parameters=1)
      )
    )
    (2): Conv2d(64, 3, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  )
)
{% endraw %} {% raw %}
learn.summary()
SRResNet
======================================================================
Layer (type)         Output Shape         Param #    Trainable 
======================================================================
Conv2d               [64, 64, 64]         15,616     True      
______________________________________________________________________
PReLU                [64, 64, 64]         1          True      
______________________________________________________________________
Conv2d               [64, 64, 64]         36,928     True      
______________________________________________________________________
BatchNorm2d          [64, 64, 64]         128        True      
______________________________________________________________________
PReLU                [64, 64, 64]         1          True      
______________________________________________________________________
Conv2d               [64, 64, 64]         36,928     True      
______________________________________________________________________
BatchNorm2d          [64, 64, 64]         128        True      
______________________________________________________________________
Conv2d               [64, 64, 64]         36,928     True      
______________________________________________________________________
BatchNorm2d          [64, 64, 64]         128        True      
______________________________________________________________________
PReLU                [64, 64, 64]         1          True      
______________________________________________________________________
Conv2d               [64, 64, 64]         36,928     True      
______________________________________________________________________
BatchNorm2d          [64, 64, 64]         128        True      
______________________________________________________________________
Conv2d               [64, 64, 64]         36,928     True      
______________________________________________________________________
BatchNorm2d          [64, 64, 64]         128        True      
______________________________________________________________________
PReLU                [64, 64, 64]         1          True      
______________________________________________________________________
Conv2d               [64, 64, 64]         36,928     True      
______________________________________________________________________
BatchNorm2d          [64, 64, 64]         128        True      
______________________________________________________________________
Conv2d               [64, 64, 64]         36,928     True      
______________________________________________________________________
BatchNorm2d          [64, 64, 64]         128        True      
______________________________________________________________________
PReLU                [64, 64, 64]         1          True      
______________________________________________________________________
Conv2d               [64, 64, 64]         36,928     True      
______________________________________________________________________
BatchNorm2d          [64, 64, 64]         128        True      
______________________________________________________________________
Conv2d               [64, 64, 64]         36,928     True      
______________________________________________________________________
BatchNorm2d          [64, 64, 64]         128        True      
______________________________________________________________________
PReLU                [64, 64, 64]         1          True      
______________________________________________________________________
Conv2d               [64, 64, 64]         36,928     True      
______________________________________________________________________
BatchNorm2d          [64, 64, 64]         128        True      
______________________________________________________________________
Conv2d               [64, 64, 64]         36,928     True      
______________________________________________________________________
BatchNorm2d          [64, 64, 64]         128        True      
______________________________________________________________________
Conv2d               [256, 64, 64]        147,712    True      
______________________________________________________________________
PixelShuffle         [64, 128, 128]       0          False     
______________________________________________________________________
PReLU                [64, 128, 128]       1          True      
______________________________________________________________________
Conv2d               [256, 128, 128]      147,712    True      
______________________________________________________________________
PixelShuffle         [64, 256, 256]       0          False     
______________________________________________________________________
PReLU                [64, 256, 256]       1          True      
______________________________________________________________________
Conv2d               [3, 256, 256]        15,555     True      
______________________________________________________________________

Total params: 734,219
Total trainable params: 734,219
Total non-trainable params: 0
Optimized with 'torch.optim.adam.Adam', betas=(0.9, 0.99)
Using true weight decay as discussed in https://www.fast.ai/2018/07/02/adam-weight-decay/ 
Loss function : FlattenedLoss
======================================================================
Callbacks functions applied 
{% endraw %}